#!/usr/bin/python

import sys, os, signal, subprocess, errno, time

# Seconds to sleep after launching the clients and before counting starts.
WARMUP = 3
# Length of the measured window in seconds; also the divisor used when the
# messages/sec/core figure is reported.
DURATION = 15

# Number of "processor" entries listed in /proc/cpuinfo.  The file is opened
# with a context manager so the handle is closed deterministically instead of
# being left to garbage collection.
with open("/proc/cpuinfo") as _cpuinfo:
    CORES = _cpuinfo.read().count("processor\t")

def usage():
    """Write the command-line synopsis to stderr and exit with status 2."""
    synopsis = "Usage: %s count port" % sys.argv[0]
    print >> sys.stderr, synopsis
    raise SystemExit(2)

def setup():
    """Prepare this process for driving the benchmark.

    Makes the process the leader of a new process group, ignores SIGUSR1
    and SIGUSR2 in this process, and changes the working directory to the
    directory containing the script.

    Raises OSError if the process group or working directory cannot be
    changed.
    """
    # Put ourselves in a new process group so we can broadcast signals
    os.setpgid(0, 0)
    # Ignore USR1 and USR2, which we broadcast
    for sig in (signal.SIGUSR1, signal.SIGUSR2):
        signal.signal(sig, signal.SIG_IGN)
    # Get into the right directory.  sys.argv[0] is made absolute first:
    # when the script is started by bare name its dirname is "", and
    # os.chdir("") raises OSError.
    scriptDir = os.path.dirname(os.path.abspath(sys.argv[0]))
    os.chdir(scriptDir)

def onlinecpu():
    """Return the CPU ids reported by the `numa` module as a flat list.

    Nodes are visited from 0 through numa.get_max_node() inclusive, and
    the CPUs of each node are listed in ascending order before moving on
    to the next node.
    """
    # Imported here rather than at module level, as in the original code.
    import numa
    cpus = []
    for node in range(numa.get_max_node() + 1):
        cpus.extend(sorted(numa.node_to_cpus(node)))
    return cpus

def startLoad(count, port):
    """Launch `count` smtpbm clients against 127.0.0.1:`port`.

    Each client is started through `numactl -C <cpu>`, cycling through the
    CPU list from onlinecpu(), with its stdout captured on a pipe.  Client
    number i is given the addresses "<i>@mosbench.org" and
    "mosbench@mosbench.org".

    Returns the list of subprocess.Popen objects in launch order.
    """
    cpus = onlinecpu()
    children = []
    for index in range(count):
        # NOTE(review): the index wraps at CORES (from /proc/cpuinfo), not
        # at len(cpus); this raises IndexError if numa reports fewer CPUs
        # than CORES -- confirm the two always agree.
        pinnedCpu = cpus[index % CORES]
        argv = [
            "numactl", "-C", str(pinnedCpu),
            "./smtpbm", "127.0.0.1", str(port),
            "%d@mosbench.org" % index, "mosbench@mosbench.org",
        ]
        # smtpbm will exit if its parent dies, which is why there is so
        # little exception handling around the children.
        child = subprocess.Popen(argv, stdout = subprocess.PIPE)
        children.append(child)
        # Stagger the launches so we don't hit the system too hard
        time.sleep(0.05)
    return children

def startCounting(procs):
    """Announce the start of the measured window and signal the group.

    Prints "Starting" on stdout, flushes it, then sends SIGUSR1 to every
    process in our process group (presumably the cue for smtpbm to start
    counting -- confirm against smtpbm).  `procs` is accepted but unused.
    """
    out = sys.stdout
    print >> out, "Starting"
    out.flush()
    # PID 0 addresses the caller's whole process group.
    os.kill(0, signal.SIGUSR1)

def timedWait(timeout):
    """Wait up to `timeout` seconds for any child process to exit.

    Returns the (pid, status) pair from os.wait() if a child exits first,
    or (0, 0) if the wait was interrupted by our SIGALRM.  Any other
    OSError from os.wait() propagates.  The alarm is always cancelled
    before returning; the SIGALRM handler is left installed.
    """
    state = {"fired": False}

    def markFired(signum, frame):
        state["fired"] = True

    signal.signal(signal.SIGALRM, markFired)
    signal.alarm(timeout)
    try:
        try:
            reaped = os.wait()
        except OSError as err:
            # Only an EINTR caused by our own alarm counts as a timeout.
            timedOut = err.errno == errno.EINTR and state["fired"]
            if not timedOut:
                raise
            reaped = (0, 0)
    finally:
        signal.alarm(0)
    return reaped

def run(procs, duration):
    """Let the benchmark run for `duration` seconds.

    Returns normally if no child exits during that time.  If a child does
    exit, reports it on stderr, kills the remaining processes in `procs`
    and exits with status 1.
    """
    pid, status = timedWait(duration)
    if pid == 0:
        # Timed out without any smtpbm dying: the expected outcome.
        return
    report = "smtpbm PID %d exited unexpectedly with %s" % \
        (pid, prettyWait(status))
    print >> sys.stderr, report
    # The dead child has already been reaped, so leave it out.
    killall(procs, pid)
    sys.exit(1)

def stopCounting(procs):
    """End the measured window and print the aggregate message count.

    Sends SIGUSR2 to the whole process group (presumably the cue for
    smtpbm to stop counting and report -- confirm against smtpbm), then
    reads one line from each child's stdout and sums the integer in its
    first whitespace-separated field.  Prints "Stopped" followed by the
    total and the per-second, per-core rate.

    Raises IOError if a child takes longer than a second to produce its
    line.
    """
    os.kill(0, signal.SIGUSR2)

    def timeoutHandlerFor(proc):
        def onAlarm(signum, frame):
            raise IOError("Timed out reading from smtpbm %d" % proc.pid)
        return onAlarm

    counts = []
    for proc in procs:
        signal.signal(signal.SIGALRM, timeoutHandlerFor(proc))
        # If it takes longer than a second to get the output, we're
        # skewed and screwed.  This is mostly a safeguard to catch
        # misbehaving children.  The IOError will kill this process,
        # which will kill the children.
        signal.alarm(1)
        line = proc.stdout.readline()
        signal.alarm(0)
        counts.append(int(line.split()[0]))
    total = sum(counts)

    out = sys.stdout
    print >> out, "Stopped"
    out.flush()
    rate = float(total)/DURATION/CORES
    print >> out, "%d messages (%.2f messages/sec/core)" % (total, rate)

def killall(procs, butPid = -1):
    """Interrupt and reap every process in `procs` that is still running.

    procs  -- iterable of subprocess.Popen objects.
    butPid -- PID to leave alone (run() passes the PID that os.wait()
              already reaped); the default -1 matches no process.

    Sends SIGINT to each process that has not yet exited, then waits for
    each of them.  This is best effort: an EnvironmentError from either
    step is printed to stderr and the remaining processes are still
    handled.
    """
    # poll() returns None only while the process is still running, so
    # test identity with `is None` rather than equality.
    pending = [p for p in procs if p.pid != butPid and p.poll() is None]
    # Signal everything first so the children shut down in parallel ...
    for p in pending:
        try:
            os.kill(p.pid, signal.SIGINT)
        except EnvironmentError as e:
            print >> sys.stderr, e
    # ... and only then block on each one in turn.
    for p in pending:
        try:
            p.wait()
        except EnvironmentError as e:
            print >> sys.stderr, e

def prettyWait(status):
    """Describe a wait status as returned by os.wait().

    Returns "status N" for a normal exit, "signal N" for death by signal,
    and "unknown wait status N" for anything else.
    """
    decoders = (
        (os.WIFEXITED, "status %d", os.WEXITSTATUS),
        (os.WIFSIGNALED, "signal %d", os.WTERMSIG),
    )
    for applies, template, extract in decoders:
        if applies(status):
            return template % extract(status)
    return "unknown wait status %d" % status

# The command line must be exactly two arguments made only of digits.
if not (len(sys.argv) == 3 and sys.argv[1].isdigit() and sys.argv[2].isdigit()):
    usage()

count = int(sys.argv[1])
port = int(sys.argv[2])

# Benchmark sequence: prepare the process, launch the clients, give them
# WARMUP seconds, then measure for DURATION seconds and collect the counts.
setup()
procs = startLoad(count, port)
time.sleep(WARMUP)
startCounting(procs)
run(procs, DURATION)
stopCounting(procs)
# Shut down whatever clients are still running.
killall(procs)
